{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cot import Collection\n",
    "from cot.stats import evaluation_as_table\n",
    "from numpy import loadtxt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the ThoughtSource-100 collection that every following cell works on\n",
    "coll = Collection.load_thoughtsource_100()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "# restrict the collection to the CoTs we generated ourselves (author == thoughtsource)\n",
    "coll.select_generated_cots(author=\"thoughtsource\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/kon/work/ThoughtSource/libs/cot/cot/stats.py:406: PerformanceWarning: indexing past lexsort depth may impact performance.\n",
      "  df.loc[dataset, (instruction + \"_\" + cot_trigger, model)] = v\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_724f2_row0_col1, #T_724f2_row1_col15, #T_724f2_row2_col15, #T_724f2_row3_col15, #T_724f2_row4_col15, #T_724f2_row5_col3, #T_724f2_row6_col15 {\n",
       "  font-weight: bold;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_724f2\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_724f2_level0_col0\" class=\"col_heading level0 col0\" colspan=\"6\">None_None</th>\n",
       "      <th id=\"T_724f2_level0_col6\" class=\"col_heading level0 col6\" colspan=\"5\">None_kojima-01</th>\n",
       "      <th id=\"T_724f2_level0_col11\" class=\"col_heading level0 col11\" >None_kojima-03</th>\n",
       "      <th id=\"T_724f2_level0_col12\" class=\"col_heading level0 col12\" >None_kojima-09</th>\n",
       "      <th id=\"T_724f2_level0_col13\" class=\"col_heading level0 col13\" colspan=\"3\">None_zhou-01</th>\n",
       "      <th id=\"T_724f2_level0_col16\" class=\"col_heading level0 col16\" >qa-01_None</th>\n",
       "      <th id=\"T_724f2_level0_col17\" class=\"col_heading level0 col17\" >qa-05_None</th>\n",
       "      <th id=\"T_724f2_level0_col18\" class=\"col_heading level0 col18\" >qa-08_None</th>\n",
       "      <th id=\"T_724f2_level0_col19\" class=\"col_heading level0 col19\" >qa-09_None</th>\n",
       "      <th id=\"T_724f2_level0_col20\" class=\"col_heading level0 col20\" >qa-10_None</th>\n",
       "      <th id=\"T_724f2_level0_col21\" class=\"col_heading level0 col21\" >qa-12_None</th>\n",
       "      <th id=\"T_724f2_level0_col22\" class=\"col_heading level0 col22\" >qa-13_None</th>\n",
       "      <th id=\"T_724f2_level0_col23\" class=\"col_heading level0 col23\" >qa-16_None</th>\n",
       "      <th id=\"T_724f2_level0_col24\" class=\"col_heading level0 col24\" >qa-17_None</th>\n",
       "      <th id=\"T_724f2_level0_col25\" class=\"col_heading level0 col25\" >zhou-01-ins_None</th>\n",
       "      <th id=\"T_724f2_level0_col26\" class=\"col_heading level0 col26\" >zhou-01-ins_zhou-01</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th class=\"blank level1\" >&nbsp;</th>\n",
       "      <th id=\"T_724f2_level1_col0\" class=\"col_heading level1 col0\" >command-xlarge-nightly</th>\n",
       "      <th id=\"T_724f2_level1_col1\" class=\"col_heading level1 col1\" >flan-T5-xxl</th>\n",
       "      <th id=\"T_724f2_level1_col2\" class=\"col_heading level1 col2\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col3\" class=\"col_heading level1 col3\" >gpt-4</th>\n",
       "      <th id=\"T_724f2_level1_col4\" class=\"col_heading level1 col4\" >text-davinci-002</th>\n",
       "      <th id=\"T_724f2_level1_col5\" class=\"col_heading level1 col5\" >text-davinci-003</th>\n",
       "      <th id=\"T_724f2_level1_col6\" class=\"col_heading level1 col6\" >command-xlarge-nightly</th>\n",
       "      <th id=\"T_724f2_level1_col7\" class=\"col_heading level1 col7\" >flan-T5-xxl</th>\n",
       "      <th id=\"T_724f2_level1_col8\" class=\"col_heading level1 col8\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col9\" class=\"col_heading level1 col9\" >text-davinci-002</th>\n",
       "      <th id=\"T_724f2_level1_col10\" class=\"col_heading level1 col10\" >text-davinci-003</th>\n",
       "      <th id=\"T_724f2_level1_col11\" class=\"col_heading level1 col11\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col12\" class=\"col_heading level1 col12\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col13\" class=\"col_heading level1 col13\" >flan-T5-xxl</th>\n",
       "      <th id=\"T_724f2_level1_col14\" class=\"col_heading level1 col14\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col15\" class=\"col_heading level1 col15\" >gpt-4</th>\n",
       "      <th id=\"T_724f2_level1_col16\" class=\"col_heading level1 col16\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col17\" class=\"col_heading level1 col17\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col18\" class=\"col_heading level1 col18\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col19\" class=\"col_heading level1 col19\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col20\" class=\"col_heading level1 col20\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col21\" class=\"col_heading level1 col21\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col22\" class=\"col_heading level1 col22\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col23\" class=\"col_heading level1 col23\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col24\" class=\"col_heading level1 col24\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col25\" class=\"col_heading level1 col25\" >gpt-3.5-turbo</th>\n",
       "      <th id=\"T_724f2_level1_col26\" class=\"col_heading level1 col26\" >gpt-3.5-turbo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_724f2_level0_row0\" class=\"row_heading level0 row0\" >commonsense_qa</th>\n",
       "      <td id=\"T_724f2_row0_col0\" class=\"data row0 col0\" >0.51</td>\n",
       "      <td id=\"T_724f2_row0_col1\" class=\"data row0 col1\" >0.87</td>\n",
       "      <td id=\"T_724f2_row0_col2\" class=\"data row0 col2\" >0.72</td>\n",
       "      <td id=\"T_724f2_row0_col3\" class=\"data row0 col3\" >0.75</td>\n",
       "      <td id=\"T_724f2_row0_col4\" class=\"data row0 col4\" >0.76</td>\n",
       "      <td id=\"T_724f2_row0_col5\" class=\"data row0 col5\" >0.72</td>\n",
       "      <td id=\"T_724f2_row0_col6\" class=\"data row0 col6\" >0.53</td>\n",
       "      <td id=\"T_724f2_row0_col7\" class=\"data row0 col7\" >0.81</td>\n",
       "      <td id=\"T_724f2_row0_col8\" class=\"data row0 col8\" >0.67</td>\n",
       "      <td id=\"T_724f2_row0_col9\" class=\"data row0 col9\" >0.62</td>\n",
       "      <td id=\"T_724f2_row0_col10\" class=\"data row0 col10\" >0.65</td>\n",
       "      <td id=\"T_724f2_row0_col11\" class=\"data row0 col11\" >0.63</td>\n",
       "      <td id=\"T_724f2_row0_col12\" class=\"data row0 col12\" >0.70</td>\n",
       "      <td id=\"T_724f2_row0_col13\" class=\"data row0 col13\" >0.83</td>\n",
       "      <td id=\"T_724f2_row0_col14\" class=\"data row0 col14\" >0.66</td>\n",
       "      <td id=\"T_724f2_row0_col15\" class=\"data row0 col15\" >0.72</td>\n",
       "      <td id=\"T_724f2_row0_col16\" class=\"data row0 col16\" >0.66</td>\n",
       "      <td id=\"T_724f2_row0_col17\" class=\"data row0 col17\" >0.69</td>\n",
       "      <td id=\"T_724f2_row0_col18\" class=\"data row0 col18\" >0.62</td>\n",
       "      <td id=\"T_724f2_row0_col19\" class=\"data row0 col19\" >0.64</td>\n",
       "      <td id=\"T_724f2_row0_col20\" class=\"data row0 col20\" >0.68</td>\n",
       "      <td id=\"T_724f2_row0_col21\" class=\"data row0 col21\" >0.63</td>\n",
       "      <td id=\"T_724f2_row0_col22\" class=\"data row0 col22\" >0.61</td>\n",
       "      <td id=\"T_724f2_row0_col23\" class=\"data row0 col23\" >0.58</td>\n",
       "      <td id=\"T_724f2_row0_col24\" class=\"data row0 col24\" >0.66</td>\n",
       "      <td id=\"T_724f2_row0_col25\" class=\"data row0 col25\" >0.72</td>\n",
       "      <td id=\"T_724f2_row0_col26\" class=\"data row0 col26\" >0.65</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_724f2_level0_row1\" class=\"row_heading level0 row1\" >med_qa</th>\n",
       "      <td id=\"T_724f2_row1_col0\" class=\"data row1 col0\" >0.23</td>\n",
       "      <td id=\"T_724f2_row1_col1\" class=\"data row1 col1\" >0.32</td>\n",
       "      <td id=\"T_724f2_row1_col2\" class=\"data row1 col2\" >0.58</td>\n",
       "      <td id=\"T_724f2_row1_col3\" class=\"data row1 col3\" >0.73</td>\n",
       "      <td id=\"T_724f2_row1_col4\" class=\"data row1 col4\" >0.41</td>\n",
       "      <td id=\"T_724f2_row1_col5\" class=\"data row1 col5\" >0.43</td>\n",
       "      <td id=\"T_724f2_row1_col6\" class=\"data row1 col6\" >0.31</td>\n",
       "      <td id=\"T_724f2_row1_col7\" class=\"data row1 col7\" >0.34</td>\n",
       "      <td id=\"T_724f2_row1_col8\" class=\"data row1 col8\" >0.59</td>\n",
       "      <td id=\"T_724f2_row1_col9\" class=\"data row1 col9\" >0.34</td>\n",
       "      <td id=\"T_724f2_row1_col10\" class=\"data row1 col10\" >0.43</td>\n",
       "      <td id=\"T_724f2_row1_col11\" class=\"data row1 col11\" >0.59</td>\n",
       "      <td id=\"T_724f2_row1_col12\" class=\"data row1 col12\" >0.51</td>\n",
       "      <td id=\"T_724f2_row1_col13\" class=\"data row1 col13\" >0.27</td>\n",
       "      <td id=\"T_724f2_row1_col14\" class=\"data row1 col14\" >0.65</td>\n",
       "      <td id=\"T_724f2_row1_col15\" class=\"data row1 col15\" >0.76</td>\n",
       "      <td id=\"T_724f2_row1_col16\" class=\"data row1 col16\" >0.54</td>\n",
       "      <td id=\"T_724f2_row1_col17\" class=\"data row1 col17\" >0.53</td>\n",
       "      <td id=\"T_724f2_row1_col18\" class=\"data row1 col18\" >0.46</td>\n",
       "      <td id=\"T_724f2_row1_col19\" class=\"data row1 col19\" >0.55</td>\n",
       "      <td id=\"T_724f2_row1_col20\" class=\"data row1 col20\" >0.56</td>\n",
       "      <td id=\"T_724f2_row1_col21\" class=\"data row1 col21\" >0.59</td>\n",
       "      <td id=\"T_724f2_row1_col22\" class=\"data row1 col22\" >0.49</td>\n",
       "      <td id=\"T_724f2_row1_col23\" class=\"data row1 col23\" >0.60</td>\n",
       "      <td id=\"T_724f2_row1_col24\" class=\"data row1 col24\" >0.56</td>\n",
       "      <td id=\"T_724f2_row1_col25\" class=\"data row1 col25\" >0.54</td>\n",
       "      <td id=\"T_724f2_row1_col26\" class=\"data row1 col26\" >0.52</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_724f2_level0_row2\" class=\"row_heading level0 row2\" >medmc_qa</th>\n",
       "      <td id=\"T_724f2_row2_col0\" class=\"data row2 col0\" >0.25</td>\n",
       "      <td id=\"T_724f2_row2_col1\" class=\"data row2 col1\" >0.34</td>\n",
       "      <td id=\"T_724f2_row2_col2\" class=\"data row2 col2\" >0.58</td>\n",
       "      <td id=\"T_724f2_row2_col3\" class=\"data row2 col3\" >0.69</td>\n",
       "      <td id=\"T_724f2_row2_col4\" class=\"data row2 col4\" >0.34</td>\n",
       "      <td id=\"T_724f2_row2_col5\" class=\"data row2 col5\" >0.40</td>\n",
       "      <td id=\"T_724f2_row2_col6\" class=\"data row2 col6\" >0.22</td>\n",
       "      <td id=\"T_724f2_row2_col7\" class=\"data row2 col7\" >0.35</td>\n",
       "      <td id=\"T_724f2_row2_col8\" class=\"data row2 col8\" >0.47</td>\n",
       "      <td id=\"T_724f2_row2_col9\" class=\"data row2 col9\" >0.34</td>\n",
       "      <td id=\"T_724f2_row2_col10\" class=\"data row2 col10\" >0.36</td>\n",
       "      <td id=\"T_724f2_row2_col11\" class=\"data row2 col11\" >0.50</td>\n",
       "      <td id=\"T_724f2_row2_col12\" class=\"data row2 col12\" >0.50</td>\n",
       "      <td id=\"T_724f2_row2_col13\" class=\"data row2 col13\" >0.31</td>\n",
       "      <td id=\"T_724f2_row2_col14\" class=\"data row2 col14\" >0.48</td>\n",
       "      <td id=\"T_724f2_row2_col15\" class=\"data row2 col15\" >0.70</td>\n",
       "      <td id=\"T_724f2_row2_col16\" class=\"data row2 col16\" >0.47</td>\n",
       "      <td id=\"T_724f2_row2_col17\" class=\"data row2 col17\" >0.42</td>\n",
       "      <td id=\"T_724f2_row2_col18\" class=\"data row2 col18\" >0.45</td>\n",
       "      <td id=\"T_724f2_row2_col19\" class=\"data row2 col19\" >0.47</td>\n",
       "      <td id=\"T_724f2_row2_col20\" class=\"data row2 col20\" >0.49</td>\n",
       "      <td id=\"T_724f2_row2_col21\" class=\"data row2 col21\" >0.53</td>\n",
       "      <td id=\"T_724f2_row2_col22\" class=\"data row2 col22\" >0.41</td>\n",
       "      <td id=\"T_724f2_row2_col23\" class=\"data row2 col23\" >0.48</td>\n",
       "      <td id=\"T_724f2_row2_col24\" class=\"data row2 col24\" >0.53</td>\n",
       "      <td id=\"T_724f2_row2_col25\" class=\"data row2 col25\" >0.44</td>\n",
       "      <td id=\"T_724f2_row2_col26\" class=\"data row2 col26\" >0.40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_724f2_level0_row3\" class=\"row_heading level0 row3\" >open_book_qa</th>\n",
       "      <td id=\"T_724f2_row3_col0\" class=\"data row3 col0\" >0.59</td>\n",
       "      <td id=\"T_724f2_row3_col1\" class=\"data row3 col1\" >0.82</td>\n",
       "      <td id=\"T_724f2_row3_col2\" class=\"data row3 col2\" >0.77</td>\n",
       "      <td id=\"T_724f2_row3_col3\" class=\"data row3 col3\" >0.92</td>\n",
       "      <td id=\"T_724f2_row3_col4\" class=\"data row3 col4\" >0.67</td>\n",
       "      <td id=\"T_724f2_row3_col5\" class=\"data row3 col5\" >0.70</td>\n",
       "      <td id=\"T_724f2_row3_col6\" class=\"data row3 col6\" >0.38</td>\n",
       "      <td id=\"T_724f2_row3_col7\" class=\"data row3 col7\" >0.79</td>\n",
       "      <td id=\"T_724f2_row3_col8\" class=\"data row3 col8\" >0.77</td>\n",
       "      <td id=\"T_724f2_row3_col9\" class=\"data row3 col9\" >0.57</td>\n",
       "      <td id=\"T_724f2_row3_col10\" class=\"data row3 col10\" >0.67</td>\n",
       "      <td id=\"T_724f2_row3_col11\" class=\"data row3 col11\" >0.73</td>\n",
       "      <td id=\"T_724f2_row3_col12\" class=\"data row3 col12\" >0.73</td>\n",
       "      <td id=\"T_724f2_row3_col13\" class=\"data row3 col13\" >0.81</td>\n",
       "      <td id=\"T_724f2_row3_col14\" class=\"data row3 col14\" >0.81</td>\n",
       "      <td id=\"T_724f2_row3_col15\" class=\"data row3 col15\" >0.95</td>\n",
       "      <td id=\"T_724f2_row3_col16\" class=\"data row3 col16\" >0.73</td>\n",
       "      <td id=\"T_724f2_row3_col17\" class=\"data row3 col17\" >0.65</td>\n",
       "      <td id=\"T_724f2_row3_col18\" class=\"data row3 col18\" >0.73</td>\n",
       "      <td id=\"T_724f2_row3_col19\" class=\"data row3 col19\" >0.71</td>\n",
       "      <td id=\"T_724f2_row3_col20\" class=\"data row3 col20\" >0.73</td>\n",
       "      <td id=\"T_724f2_row3_col21\" class=\"data row3 col21\" >0.80</td>\n",
       "      <td id=\"T_724f2_row3_col22\" class=\"data row3 col22\" >0.72</td>\n",
       "      <td id=\"T_724f2_row3_col23\" class=\"data row3 col23\" >0.69</td>\n",
       "      <td id=\"T_724f2_row3_col24\" class=\"data row3 col24\" >0.69</td>\n",
       "      <td id=\"T_724f2_row3_col25\" class=\"data row3 col25\" >0.76</td>\n",
       "      <td id=\"T_724f2_row3_col26\" class=\"data row3 col26\" >0.74</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_724f2_level0_row4\" class=\"row_heading level0 row4\" >strategy_qa</th>\n",
       "      <td id=\"T_724f2_row4_col0\" class=\"data row4 col0\" >0.52</td>\n",
       "      <td id=\"T_724f2_row4_col1\" class=\"data row4 col1\" >0.61</td>\n",
       "      <td id=\"T_724f2_row4_col2\" class=\"data row4 col2\" >0.57</td>\n",
       "      <td id=\"T_724f2_row4_col3\" class=\"data row4 col3\" >0.71</td>\n",
       "      <td id=\"T_724f2_row4_col4\" class=\"data row4 col4\" >0.36</td>\n",
       "      <td id=\"T_724f2_row4_col5\" class=\"data row4 col5\" >0.53</td>\n",
       "      <td id=\"T_724f2_row4_col6\" class=\"data row4 col6\" >0.59</td>\n",
       "      <td id=\"T_724f2_row4_col7\" class=\"data row4 col7\" >0.69</td>\n",
       "      <td id=\"T_724f2_row4_col8\" class=\"data row4 col8\" >0.56</td>\n",
       "      <td id=\"T_724f2_row4_col9\" class=\"data row4 col9\" >0.38</td>\n",
       "      <td id=\"T_724f2_row4_col10\" class=\"data row4 col10\" >0.54</td>\n",
       "      <td id=\"T_724f2_row4_col11\" class=\"data row4 col11\" >0.52</td>\n",
       "      <td id=\"T_724f2_row4_col12\" class=\"data row4 col12\" >0.58</td>\n",
       "      <td id=\"T_724f2_row4_col13\" class=\"data row4 col13\" >0.59</td>\n",
       "      <td id=\"T_724f2_row4_col14\" class=\"data row4 col14\" >0.59</td>\n",
       "      <td id=\"T_724f2_row4_col15\" class=\"data row4 col15\" >0.80</td>\n",
       "      <td id=\"T_724f2_row4_col16\" class=\"data row4 col16\" >0.44</td>\n",
       "      <td id=\"T_724f2_row4_col17\" class=\"data row4 col17\" >0.43</td>\n",
       "      <td id=\"T_724f2_row4_col18\" class=\"data row4 col18\" >0.58</td>\n",
       "      <td id=\"T_724f2_row4_col19\" class=\"data row4 col19\" >0.62</td>\n",
       "      <td id=\"T_724f2_row4_col20\" class=\"data row4 col20\" >0.56</td>\n",
       "      <td id=\"T_724f2_row4_col21\" class=\"data row4 col21\" >0.50</td>\n",
       "      <td id=\"T_724f2_row4_col22\" class=\"data row4 col22\" >0.64</td>\n",
       "      <td id=\"T_724f2_row4_col23\" class=\"data row4 col23\" >0.63</td>\n",
       "      <td id=\"T_724f2_row4_col24\" class=\"data row4 col24\" >0.58</td>\n",
       "      <td id=\"T_724f2_row4_col25\" class=\"data row4 col25\" >0.52</td>\n",
       "      <td id=\"T_724f2_row4_col26\" class=\"data row4 col26\" >0.57</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_724f2_level0_row5\" class=\"row_heading level0 row5\" >worldtree</th>\n",
       "      <td id=\"T_724f2_row5_col0\" class=\"data row5 col0\" >0.63</td>\n",
       "      <td id=\"T_724f2_row5_col1\" class=\"data row5 col1\" >0.85</td>\n",
       "      <td id=\"T_724f2_row5_col2\" class=\"data row5 col2\" >0.96</td>\n",
       "      <td id=\"T_724f2_row5_col3\" class=\"data row5 col3\" >0.99</td>\n",
       "      <td id=\"T_724f2_row5_col4\" class=\"data row5 col4\" >0.88</td>\n",
       "      <td id=\"T_724f2_row5_col5\" class=\"data row5 col5\" >0.91</td>\n",
       "      <td id=\"T_724f2_row5_col6\" class=\"data row5 col6\" >0.58</td>\n",
       "      <td id=\"T_724f2_row5_col7\" class=\"data row5 col7\" >0.80</td>\n",
       "      <td id=\"T_724f2_row5_col8\" class=\"data row5 col8\" >0.93</td>\n",
       "      <td id=\"T_724f2_row5_col9\" class=\"data row5 col9\" >0.78</td>\n",
       "      <td id=\"T_724f2_row5_col10\" class=\"data row5 col10\" >0.89</td>\n",
       "      <td id=\"T_724f2_row5_col11\" class=\"data row5 col11\" >0.95</td>\n",
       "      <td id=\"T_724f2_row5_col12\" class=\"data row5 col12\" >0.95</td>\n",
       "      <td id=\"T_724f2_row5_col13\" class=\"data row5 col13\" >0.83</td>\n",
       "      <td id=\"T_724f2_row5_col14\" class=\"data row5 col14\" >0.92</td>\n",
       "      <td id=\"T_724f2_row5_col15\" class=\"data row5 col15\" >0.99</td>\n",
       "      <td id=\"T_724f2_row5_col16\" class=\"data row5 col16\" >0.95</td>\n",
       "      <td id=\"T_724f2_row5_col17\" class=\"data row5 col17\" >0.74</td>\n",
       "      <td id=\"T_724f2_row5_col18\" class=\"data row5 col18\" >0.92</td>\n",
       "      <td id=\"T_724f2_row5_col19\" class=\"data row5 col19\" >0.91</td>\n",
       "      <td id=\"T_724f2_row5_col20\" class=\"data row5 col20\" >0.95</td>\n",
       "      <td id=\"T_724f2_row5_col21\" class=\"data row5 col21\" >0.92</td>\n",
       "      <td id=\"T_724f2_row5_col22\" class=\"data row5 col22\" >0.96</td>\n",
       "      <td id=\"T_724f2_row5_col23\" class=\"data row5 col23\" >0.91</td>\n",
       "      <td id=\"T_724f2_row5_col24\" class=\"data row5 col24\" >0.92</td>\n",
       "      <td id=\"T_724f2_row5_col25\" class=\"data row5 col25\" >0.96</td>\n",
       "      <td id=\"T_724f2_row5_col26\" class=\"data row5 col26\" >0.96</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_724f2_level0_row6\" class=\"row_heading level0 row6\" >Average</th>\n",
       "      <td id=\"T_724f2_row6_col0\" class=\"data row6 col0\" >0.46</td>\n",
       "      <td id=\"T_724f2_row6_col1\" class=\"data row6 col1\" >0.64</td>\n",
       "      <td id=\"T_724f2_row6_col2\" class=\"data row6 col2\" >0.70</td>\n",
       "      <td id=\"T_724f2_row6_col3\" class=\"data row6 col3\" >0.80</td>\n",
       "      <td id=\"T_724f2_row6_col4\" class=\"data row6 col4\" >0.57</td>\n",
       "      <td id=\"T_724f2_row6_col5\" class=\"data row6 col5\" >0.62</td>\n",
       "      <td id=\"T_724f2_row6_col6\" class=\"data row6 col6\" >0.44</td>\n",
       "      <td id=\"T_724f2_row6_col7\" class=\"data row6 col7\" >0.63</td>\n",
       "      <td id=\"T_724f2_row6_col8\" class=\"data row6 col8\" >0.66</td>\n",
       "      <td id=\"T_724f2_row6_col9\" class=\"data row6 col9\" >0.50</td>\n",
       "      <td id=\"T_724f2_row6_col10\" class=\"data row6 col10\" >0.59</td>\n",
       "      <td id=\"T_724f2_row6_col11\" class=\"data row6 col11\" >0.65</td>\n",
       "      <td id=\"T_724f2_row6_col12\" class=\"data row6 col12\" >0.66</td>\n",
       "      <td id=\"T_724f2_row6_col13\" class=\"data row6 col13\" >0.61</td>\n",
       "      <td id=\"T_724f2_row6_col14\" class=\"data row6 col14\" >0.68</td>\n",
       "      <td id=\"T_724f2_row6_col15\" class=\"data row6 col15\" >0.82</td>\n",
       "      <td id=\"T_724f2_row6_col16\" class=\"data row6 col16\" >0.63</td>\n",
       "      <td id=\"T_724f2_row6_col17\" class=\"data row6 col17\" >0.58</td>\n",
       "      <td id=\"T_724f2_row6_col18\" class=\"data row6 col18\" >0.63</td>\n",
       "      <td id=\"T_724f2_row6_col19\" class=\"data row6 col19\" >0.65</td>\n",
       "      <td id=\"T_724f2_row6_col20\" class=\"data row6 col20\" >0.66</td>\n",
       "      <td id=\"T_724f2_row6_col21\" class=\"data row6 col21\" >0.66</td>\n",
       "      <td id=\"T_724f2_row6_col22\" class=\"data row6 col22\" >0.64</td>\n",
       "      <td id=\"T_724f2_row6_col23\" class=\"data row6 col23\" >0.65</td>\n",
       "      <td id=\"T_724f2_row6_col24\" class=\"data row6 col24\" >0.66</td>\n",
       "      <td id=\"T_724f2_row6_col25\" class=\"data row6 col25\" >0.66</td>\n",
       "      <td id=\"T_724f2_row6_col26\" class=\"data row6 col26\" >0.64</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x7f98020f3af0>"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# to get the answer_from_choices, we need to evaluate with overwrite=True for now\n",
    "# keep the returned evaluation and tabulate it; the previous code passed `eval`, which is the Python\n",
    "# builtin unless shadowed by some earlier (since deleted) cell, so it broke on a fresh kernel\n",
    "evaluation = coll.evaluate(overwrite=True)\n",
    "table = evaluation_as_table(evaluation)\n",
    "table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the answer_from_choices, here just from the first of the generated COTs as an example\n",
    "# we just use med_qa for now\n",
    "# read the generated_cot column once up front: indexing coll[...][\"generated_cot\"] inside the loop\n",
    "# presumably re-materializes the whole column on every iteration, which made this cell way too slow\n",
    "med_qa_cots = coll[\"med_qa\"][\"test\"][\"generated_cot\"]\n",
    "# per item: first generated CoT -> its first answer -> the extracted choice letter\n",
    "answer_from_choices = [cots[0][\"answers\"][0][\"answer_from_choices\"] for cots in med_qa_cots]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['D', 'C', 'D', 'C', 'D', 'D', 'D', 'D', 'D', 'E', 'C', 'C', 'B', 'A', 'A', 'C', 'C', 'B', 'E', 'B', 'D', 'D', 'A', 'E', 'E', 'B', 'C', 'B', 'D', 'E', 'D', 'A', 'A', 'C', 'A', 'D', 'D', 'D', 'E', 'D', 'A', 'D', 'B', 'C', 'D', 'D', '', 'E', 'D', 'A', 'B', 'B', 'A', 'B', 'D', 'E', 'D', 'B', 'D', '', 'C', 'D', 'D', '', 'B', 'E', 'D', '', 'D', 'D', 'D', 'B', 'E', '', 'A', 'D', 'B', 'E', 'D', 'D', 'A', 'D', 'C', 'D', 'E', '', 'E', 'C', 'C', 'D', 'D', 'A', 'B', '', 'A', 'B', 'C', 'B', 'B', '']\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): '' entries presumably mean no choice could be extracted from that CoT - confirm\n",
    "print(answer_from_choices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the correct answer letter(s) for each item of a dataset\n",
    "def find_correct_choice(lists1, lists2):\n",
    "    \"\"\"Map each item's correct answer text(s) to option letters ('A', 'B', ...).\n",
    "\n",
    "    Args:\n",
    "        lists1: per-item lists of answer choices; index 0 maps to 'A', index 1 to 'B', etc.\n",
    "        lists2: per-item lists of correct answers; each must occur in that item's choices.\n",
    "\n",
    "    Returns:\n",
    "        A flat list of letters. An item with several correct answers contributes one letter\n",
    "        per answer, so the result only lines up item-by-item when every item has one answer.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if lists1 and lists2 differ in length, or an answer is not among its\n",
    "            item's choices.\n",
    "    \"\"\"\n",
    "    output = []\n",
    "    # strict=True (Python 3.10+): a length mismatch would otherwise silently truncate\n",
    "    # and misalign the result with the predictions it is compared against\n",
    "    for choices, answers in zip(lists1, lists2, strict=True):\n",
    "        for answer in answers:\n",
    "            if answer not in choices:\n",
    "                raise ValueError(f\"answer {answer!r} not found in choices {choices!r}\")\n",
    "            output.append(chr(ord('A') + choices.index(answer)))\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['C', 'D', 'C', 'C', 'C', 'D', 'A', 'A', 'E', 'A', 'E', 'C', 'B', 'D', 'C', 'D', 'D', 'A', 'C', 'B', 'D', 'D', 'A', 'D', 'C', 'D', 'C', 'B', 'B', 'D', 'D', 'C', 'E', 'C', 'A', 'A', 'D', 'D', 'A', 'B', 'B', 'A', 'B', 'E', 'D', 'C', 'B', 'E', 'C', 'C', 'B', 'B', 'A', 'B', 'B', 'D', 'E', 'C', 'D', 'B', 'A', 'E', 'C', 'B', 'E', 'D', 'E', 'C', 'C', 'E', 'D', 'B', 'C', 'B', 'B', 'B', 'C', 'B', 'C', 'D', 'C', 'A', 'C', 'D', 'A', 'E', 'A', 'A', 'C', 'E', 'B', 'E', 'C', 'B', 'B', 'A', 'A', 'B', 'B', 'C']\n"
     ]
    }
   ],
   "source": [
    "correct_choices = find_correct_choice(coll[\"med_qa\"][\"test\"][\"choices\"], coll[\"med_qa\"][\"test\"][\"answer\"])\n",
    "print(correct_choices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "import krippendorff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.14325825170025586"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# compute metric: Krippendorff's alpha (nominal) between the extracted answers and the gold answers\n",
    "# NOTE(review): '' entries are probably counted as a category of their own rather than as\n",
    "# missing data - confirm this is the intended treatment of unanswered items\n",
    "krippendorff.alpha(\n",
    "    reliability_data=[answer_from_choices, correct_choices],\n",
    "    level_of_measurement='nominal',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): reverse=True presumably keeps only med_qa and unloads every other dataset - confirm\n",
    "coll.unload_datasets([\"med_qa\"], reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save the current collection to ts_100_med_qa_2.json, relative to the notebook's working directory\n",
    "coll.dump(\"ts_100_med_qa_2.json\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
